[MoE] Align Swiglu MXFP4 fused quant paths #3123
Remove the GPT-OSS Swiglu layout env switch in favor of GateMode, align the CSV test filter with runtime dtype selection, and restore FlyDSL Swiglu _fp4 fused quant accuracy by matching the non-fused bf16 stage1 semantics. Co-authored-by: Cursor <cursoragent@cursor.com>
Pull request overview
This PR updates the Swiglu MXFP4 MoE codepaths to remove the legacy GPT-OSS layout environment switch, align runtime q_dtype_a selection with GateMode, and restore FlyDSL fused-quant numerical behavior to match the non-fused bf16 materialization/clamp semantics.
Changes:
- Switch Swiglu MXFP4 `q_dtype_a` selection to be driven by `GateMode.SEPARATED` vs non-separated modes, and thread `gate_mode` through the 2-stage config path.
- Update CSV-driven MoE 2-stage tests to skip cases whose `q_dtype_a` no longer matches the runtime Swiglu MXFP4 selection logic (now including `gate_mode`).
- Adjust FlyDSL fused quant kernels to apply the Swiglu alpha/clamp path and bf16 round-trip prior to MXFP4 quantization to match the non-fused semantics.
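The third change rests on the fact that rounding an f32 value to bf16 before quantizing can change the value the quantizer sees. The sketch below illustrates that effect in plain Python. It is not the FlyDSL kernel code: the helper names `f32_to_bf16_bits` and `bf16_round_trip` are hypothetical, and round-to-nearest-even is assumed as the rounding mode (NaN handling is omitted).

```python
import struct


def f32_to_bf16_bits(x: float) -> int:
    """Round an f32 value to bf16 (round-to-nearest-even); return the 16-bit pattern.

    Hypothetical helper for illustration only; NaN inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit so exact ties round to an even mantissa.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_round_trip(x: float) -> float:
    """Materialize x as bf16, then widen it back to f32."""
    widened = f32_to_bf16_bits(x) << 16
    return struct.unpack("<f", struct.pack("<I", widened))[0]


# Values exactly representable in bf16 survive unchanged.
exact = bf16_round_trip(1.0)
# 1 + 2**-8 lies halfway between two bf16 values and rounds down to 1.0.
tie_down = bf16_round_trip(1.0 + 2.0 ** -8)
# 1 + 3 * 2**-8 is also a tie and rounds up to 1 + 2**-6.
tie_up = bf16_round_trip(1.0 + 3.0 * 2.0 ** -8)
```

A fused path that quantizes the f32 accumulator directly would skip this rounding step, which is one way the fused and non-fused outputs could diverge.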
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `op_tests/test_moe_2stage.py` | Updates CSV-case filtering to match runtime Swiglu MXFP4 `q_dtype_a` selection, now factoring in `gateMode`. |
| `aiter/ops/flydsl/kernels/silu_and_mul_fq.py` | Aligns fused activation/clamp behavior for Swiglu and adds bf16 round-trip to match non-fused quant semantics. |
| `aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py` | Adds bf16 materialization before MXFP4 quantization in the fused stage1 store path for Swiglu FP4. |
| `aiter/fused_moe.py` | Removes the GPT-OSS Swiglu MXFP4 layout env switch and keys runtime dtype selection/config dispatch off `gate_mode`. |
Comments suppressed due to low confidence (1)
aiter/fused_moe.py:827
`get_2stage_cfgs()` now accepts `gate_mode`, but the tuned-config lookup keys (`_INDEX_COLS`/`keys`) do not incorporate it. If SEPARATED vs INTERLEAVE share the same `q_dtype_a`/`q_dtype_w` (e.g. Swiglu MXFP4 small-M where both may be bf16+fp4), this can cause the wrong tuned kernel to be selected or make it impossible to keep separate tuned entries. Consider threading `gate_mode` through the config index (and logging) so the selected kernel is unambiguous across gate layouts.
def get_2stage_cfgs(
token,
model_dim,
inter_dim,
expert,
topk,
dtype,
q_dtype_a,
q_dtype_w,
q_type,
use_g1u1,
activation,
doweight_stage1,
hidden_pad,
intermediate_pad,
is_shuffled=True,
gate_mode=GateMode.SEPARATED.value,
):
gate_mode = GateMode(gate_mode)
_INDEX_COLS = [
"cu_num",
"token",
"model_dim",
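The concern raised in the suppressed comment can be illustrated with a small dictionary-based sketch. Everything here is hypothetical: the `GateMode` stand-in, the `make_key` helper, and the kernel names are illustrative only and are not aiter's actual lookup code or `_INDEX_COLS` contents.

```python
from enum import Enum


class GateMode(Enum):
    """Hypothetical stand-in for aiter's GateMode; member values are assumptions."""
    SEPARATED = "separated"
    INTERLEAVE = "interleave"


def make_key(token, q_dtype_a, q_dtype_w, gate_mode=None):
    """Build a tuned-config lookup key, optionally including the gate layout."""
    key = (token, q_dtype_a, q_dtype_w)
    if gate_mode is not None:
        key += (GateMode(gate_mode).value,)
    return key


# Without gate_mode in the key, the two layouts map to the same entry,
# so the second tuned row silently replaces the first.
tuned_without = {}
tuned_without[make_key(16, "bf16", "fp4x2")] = "kernel_for_separated"
tuned_without[make_key(16, "bf16", "fp4x2")] = "kernel_for_interleave"

# With gate_mode in the key, both tuned rows can coexist.
tuned_with = {}
tuned_with[make_key(16, "bf16", "fp4x2", GateMode.SEPARATED)] = "kernel_for_separated"
tuned_with[make_key(16, "bf16", "fp4x2", GateMode.INTERLEAVE)] = "kernel_for_interleave"
```

Whether the real tuned CSVs can hit this collision depends on how `q_dtype_a` is chosen per gate mode, which the comment itself only raises as a possibility.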
I think those functions should be merged into this PR: #3129. And let's discuss a more suitable integration solution.
The MXFP4 W4A16 weight-load path in oracle/mxfp4.py uses shuffle_weight_a16w4 (is_guinterleave=True), which interleaves gate/up columns within each weight tile. The CK/FlyDSL MoE kernels in aiter must be told this via gate_mode=GateMode.INTERLEAVE so they decode the gate/up packing correctly.

Without the explicit gate_mode, aiter defaults to SEPARATED and (since ROCm/aiter#3123) dispatches the (SEPARATED + Swiglu + per_1x32 + fp4x2) case to a path that returns garbage for shuffled weights or crashes during CK2stages JIT for the unshuffled Quark variant (amd/gpt-oss-20b-w-mxfp4-a-bf16). This was the root cause of ROCM-25517 (gpt-oss-120b W4A16 gsm8k acc = 0) and ROCM-25478 (gpt-oss-20b Quark JIT crash).

Other paths are unaffected:
- FP8 W8A8 (DeepSeek-V4-Pro, DeepSeek-V3.2): shuffled with quark_ocp_mx.py:shuffle_weight(layout=(16,16)), which is non-interleaved. use_mxfp4_w4a16 is False, default SEPARATED preserved.
- MXFP4 W4A4 (amd/DeepSeek-R1-0528-MXFP4): shuffled via rocm_aiter_ops.shuffle_weights, which is non-interleaved. use_mxfp4_w4a16 is False, default SEPARATED preserved.

The gate_mode kwarg was added to aiter.fused_moe in ROCm/aiter#3123 (aiter>=0.1.14). To stay compatible with older aiter shipping with vllm (e.g. aiter 0.1.13.post1 in the vllm-rocm:nightly image), we probe the aiter signature and drop the kwarg when unsupported. aiter before ROCm/aiter#3123 tolerated the implicit SEPARATED default for interleave-shuffled weights, so dropping the kwarg is safe there. GateMode itself only exists on aiter>=0.1.14 and is imported under try/except for the same reason.
Validation on MI355X (gfx950):
- vllm@main + aiter@main (6aeba41), openai/gpt-oss-120b W4A16 gsm8k: TP=1: 0.000 -> 0.905, TP=8: 0.000 -> 0.905
- vllm@main + aiter@main, amd/gpt-oss-20b-w-mxfp4-a-bf16 TP=2 enforce-eager: CK2stages JIT crash -> serves cleanly
- vllm-rocm:nightly + aiter 0.1.13.post1, openai/gpt-oss-120b W4A16 gsm8k: TP=1: 0.910 (backward-compat: gate_mode kwarg silently dropped)
- vllm-rocm:v0.22.0 + aiter@main, openai/gpt-oss-120b W4A16 gsm8k: TP=1: 0.895
- amd/gpt-oss120b-w-mxfp4-a-fp8 W4A8 (this PR composes with vllm-project#44804): TP=8 mc=1=326, mc=8=2087, mc=32=6523, mc=64=11610 tok/s

Reference: sgl-project/sglang#25580 (sglang's equivalent fix).
Recommended by aiter maintainer (XiaobingZhang) on ROCm/aiter#3586.

Signed-off-by: Rohan Potdar <rohan.potdar@amd.com>
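The backward-compatibility approach described in the commit message (probe the callee's signature, drop the kwarg when it is unsupported) might look roughly like the sketch below. This is not the vllm code: `supports_kwarg`, `call_with_optional_kwargs`, and the two `*_fused_moe` functions are hypothetical stand-ins for an older and a newer aiter entry point.

```python
import inspect


def supports_kwarg(func, name):
    """Return True if func accepts a keyword argument called name."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some builtins and extension functions expose no signature.
        return False
    if name in params:
        return True
    # A **kwargs parameter accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def call_with_optional_kwargs(func, *args, **optional):
    """Call func, silently dropping optional kwargs it does not accept."""
    kept = {k: v for k, v in optional.items() if supports_kwarg(func, k)}
    return func(*args, **kept)


# Hypothetical stand-ins for an older and a newer fused MoE entry point.
def old_fused_moe(x):
    return ("old", x)


def new_fused_moe(x, gate_mode="separated"):
    return ("new", x, gate_mode)


old_result = call_with_optional_kwargs(old_fused_moe, 1, gate_mode="interleave")
new_result = call_with_optional_kwargs(new_fused_moe, 1, gate_mode="interleave")
```

Dropping the kwarg silently is only appropriate when the older callee's default behaviour is known to be acceptable, which is the argument the commit message makes for aiter 0.1.13.post1.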
Summary
- Align `test_moe_2stage.py` references with runtime Swiglu MXFP4 fused quant semantics by using an f32 stage1 reference for FP4 fused-quant cases.
- Infer `gateMode` from dtype/layout because tuned rows do not carry an explicit `gateMode` field.

Test plan
- `podman exec zxb_vllm_gptoss bash -lc 'cd /workdir/aiter_main && python3 -m py_compile op_tests/test_moe_2stage.py aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py aiter/ops/flydsl/kernels/silu_and_mul_fq.py && git diff --check'`
- `podman exec zxb_vllm_gptoss bash -lc 'cd /workdir/aiter_main && HIP_VISIBLE_DEVICES=1 FLYDSL_RUNTIME_CACHE_DIR=/tmp/flydsl_pr3123_test_fp4 AITER_CONFIG_FMOE=/workdir/aiter_main/aiter/configs/model_configs/gptoss_fp4_tuned_fmoe.csv python3 -m op_tests.test_moe_2stage --no-legacy'`
- `podman exec zxb_vllm_gptoss bash -lc 'cd /workdir/aiter_main && HIP_VISIBLE_DEVICES=1 FLYDSL_RUNTIME_CACHE_DIR=/tmp/flydsl_pr3123_test_fp8fp4 AITER_CONFIG_FMOE=/workdir/aiter_main/aiter/configs/model_configs/gptoss_fp8fp4_tuned_fmoe.csv python3 -m op_tests.test_moe_2stage --no-legacy'`
- `podman exec zxb_vllm_gptoss bash -lc 'cd /workdir/aiter_main && HIP_VISIBLE_DEVICES=1 FLYDSL_RUNTIME_CACHE_DIR=/tmp/flydsl_pr3123_test_legacy python3 -m op_tests.test_moe_2stage --no-flydsl-csv -t 1024 -dim 3072,3072 -e 128 -k 4 -q 4 -a swiglu -s f -p t -hip 0,0'`

Test result
- `gptoss_fp4_tuned_fmoe.csv --no-legacy`: passed 8 strict CSV cases, command exit code 0.
- `gptoss_fp8fp4_tuned_fmoe.csv --no-legacy`: passed 7 strict CSV cases, command exit code 0.

Made with Cursor